from channels.generic.websocket import WebsocketConsumer,AsyncWebsocketConsumer
from django.http.request import QueryDict
from websocket.ssh import SSH
from websocket.tasks import task_add
from asset.models import *
from utils.custom_log import log_start
logger = log_start('ws')
import os
import json
import base64
import re

class AnsibleConsumer(WebsocketConsumer):
    """
    WebSocket consumer, presumably intended for Ansible-related output.

    Currently it only accepts the connection and logs what it receives.
    """

    def connect(self):
        """Accept the WebSocket and log the parsed query-string parameters."""
        # Keep the WebSocket connection open.
        self.accept()

        # Parse the parameters passed in by the frontend (query string).
        query_string = self.scope.get('query_string')
        args = QueryDict(query_string=query_string, encoding='utf-8')
        logger.info(f"QueryDict:{args}")

    def disconnect(self, code):
        """Called when the WebSocket closes; nothing to clean up."""

    def receive(self, text_data=None, bytes_data=None):
        """Log the received text frame (``None`` for binary frames)."""
        # Go through the module logger instead of print() so the output is
        # handled by the same configured handlers as the rest of this module.
        logger.info(f'AnsibleConsumer received: {text_data}')





class SSHConsumer(WebsocketConsumer):
    """
    WebSocket consumer that bridges a browser terminal to an SSH session.

    1. ``connect`` fires when the connection is established.
    2. ``disconnect`` fires when the connection is closed.
    3. ``receive`` fires after a message is received.

    Payloads exchanged with the frontend look like
    ``{'status': ..., 'message': ...}``:

    status:
        0: SSH connection OK, WebSocket OK.
        1: An unknown error occurred; close the SSH and WebSocket connections.

    message:
        When status is 1, message is the specific error text.
        When status is 0, message is data returned by SSH; the frontend
        writes it into the terminal page.
    """

    # Payload template, kept for backward compatibility. Each connection
    # gets its own fresh dict in ``connect`` because the dict is handed to
    # ``SSH`` (which presumably fills it in before sending — TODO confirm);
    # a class-level dict would be shared by every concurrent connection.
    message = {'status': 0, 'message': None}

    def connect(self):
        """
        Open the WebSocket and try to SSH into the host chosen by the frontend.

        Query-string parameters:
            id:     primary key of the ``Host`` row to connect to.
            width:  terminal width forwarded to ``SSH.connect``.
            height: terminal height forwarded to ``SSH.connect``.

        If a parameter is missing/non-numeric or the host does not exist,
        a ``status: 1`` message is sent and the WebSocket is closed instead
        of letting the exception escape the handler.
        """
        # Accept first so errors can be reported back to the browser.
        self.accept()
        # Per-connection payload dict (see the note on the class attribute).
        self.message = {'status': 0, 'message': None}

        query_string = self.scope.get('query_string')
        args = QueryDict(query_string=query_string, encoding='utf-8')
        logger.info(f"QueryDict:{args}")

        try:
            host_id = int(args.get('id'))  # row id of the Host
            width = int(args.get('width'))
            height = int(args.get('height'))
        except (TypeError, ValueError):
            # TypeError: parameter missing (None); ValueError: not an integer.
            self._abort_connect('Invalid or missing id/width/height parameter')
            return

        try:
            host_instance = Host.objects.get(id=host_id)
        except Host.DoesNotExist:
            self._abort_connect(f'Host {host_id} does not exist')
            return

        ssh_info = host_instance.ssh_id
        ssh_connect_dict = {
            'host': self._resolve_host(host_instance),
            'user': ssh_info.ssh_user,
            'port': ssh_info.ssh_port,
            'timeout': 5,
            'width': width,
            'height': height,
            'password': ssh_info.password,
            'ssh_key': ssh_info.private,
        }

        self.ssh = SSH(websocker=self, message=self.message)
        self.ssh.connect(**ssh_connect_dict)
        logger.info('connect')

    @staticmethod
    def _resolve_host(host_instance):
        """
        Return the address to SSH into for ``host_instance``.

        When ``connect_type == 1`` the floating IP is used, truncated right
        after its last digit (anything after the last digit is dropped);
        otherwise ``host_instance.ip`` is used.
        """
        if host_instance.connect_type != 1:
            return host_instance.ip
        host = host_instance.floatingip
        # The previous pattern r'(.*.?\d)?' always matched and produced None
        # for values without a digit; keep the raw value in that case.
        match = re.match(r'.*\d', host) if host else None
        return match.group(0) if match else host

    def _abort_connect(self, reason):
        """Send ``reason`` to the frontend as a status-1 message, then close."""
        logger.warning(f'SSH connect aborted: {reason}')
        self.send(text_data=json.dumps({'status': 1, 'message': reason}))
        self.close()

    def disconnect(self, close_code):
        """
        Log that the WebSocket closed.

        :param close_code: WebSocket close code; 3001 is logged separately.
        """
        if close_code == 3001:
            logger.info("close_code=3001")
        # NOTE(review): the SSH session is not closed here (a call to
        # ``self.ssh.close()`` was commented out). Confirm that the SSH helper
        # tears down its channel on its own; otherwise sessions leak.
        logger.info('Websocket disconnect')

    def receive(self, text_data=None, bytes_data=None):
        """
        Forward browser input to the SSH session.

        - Binary frames go to ``SSH.django_bytes_to_ssh`` unchanged.
        - Text frames are parsed as JSON. A dict with ``status == 0`` carries
          terminal input under ``data``; it is passed to ``SSH.shell`` and
          ``task_add`` is queued via ``.delay()``. Any other status is a
          resize request carrying ``cols`` and ``rows``.
        - Non-dict JSON values are ignored.
        """
        ssh = getattr(self, 'ssh', None)
        if ssh is None:
            # connect() aborted before an SSH session was created.
            logger.warning('receive called without an SSH session; ignoring')
            return

        if text_data is None:
            ssh.django_bytes_to_ssh(bytes_data)
            return

        data = json.loads(text_data)
        if not isinstance(data, dict):
            return

        if data['status'] == 0:
            data = data['data']
            # NOTE(review): this logs every keystroke at INFO, including text
            # typed at password prompts; consider lowering or removing it.
            logger.info(f'接收到前端send::{data}')
            ssh.shell(data)
            task_add.delay()
        else:
            # Resizing the terminal is routine, so log at INFO, not ERROR.
            logger.info('调整窗口中')
            ssh.resize_pty(cols=data['cols'], rows=data['rows'])

    def send_message(self, event):
        """
        Send ``event`` to the browser as a JSON text frame.

        Custom handler used when Celery tasks push messages through the
        channel layer (presumably events of type ``send.message``).
        """
        self.send(text_data=json.dumps(event))